{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Use one of the following command to start using pyalink:\n",
      "使用以下一条命令来开始使用 pyalink：\n",
      " - useLocalEnv(parallelism, flinkHome=None, config=None)\n",
      " - useRemoteEnv(host, port, parallelism, flinkHome=None, localIp=\"localhost\", config=None)\n",
      "Call resetEnv() to reset environment and switch to another.\n",
      "使用 resetEnv() 来重置运行环境，并切换到另一个。\n",
      "\n",
      "JVM listening on 127.0.0.1:57514\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "JavaObject id=o6"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bring the pyalink operators and environment helpers (resetEnv, useLocalEnv, ...) into scope.\n",
    "from pyalink.alink import *\n",
    "\n",
    "# Reset any previously selected environment, then run locally with parallelism 1.\n",
    "resetEnv()\n",
    "useLocalEnv(parallelism=1, config=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 手写数字识别\n",
    "- 使用 Softmax 训练模型\n",
    "- 使用模型预测\n",
    "- 评估预测结果\n",
    "\n",
    "# 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "URL = \"https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_dense.csv\"\n",
    "SCHEMA_STR = \"label bigint,bitmap string\"\n",
    "\n",
    "# Read the remote CSV as two ';'-delimited columns: the label and the bitmap string.\n",
    "mnist_data = (\n",
    "    CsvSourceBatchOp()\n",
    "    .setFilePath(URL)\n",
    "    .setSchemaStr(SCHEMA_STR)\n",
    "    .setFieldDelimiter(\";\")\n",
    ")\n",
    "\n",
    "# Split with fraction 0.8: the main output is used for training,\n",
    "# side output 0 is kept as the held-out test set.\n",
    "splitter = SplitBatchOp().setFraction(0.8)\n",
    "train = splitter.linkFrom(mnist_data)\n",
    "test = splitter.getSideOutput(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练 + 预测 + 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax classifier over the \"bitmap\" vector column; the predicted label is written\n",
    "# to \"pred\" and the prediction detail to \"detail\".\n",
    "softmax = (\n",
    "    Softmax()\n",
    "    .setVectorCol(\"bitmap\")\n",
    "    .setLabelCol(\"label\")\n",
    "    .setPredictionCol(\"pred\")\n",
    "    .setPredictionDetailCol(\"detail\")\n",
    "    .setEpsilon(0.0001)\n",
    "    .setMaxIter(200)\n",
    ")\n",
    "model = softmax.fit(train)\n",
    "res = model.transform(test)\n",
    "\n",
    "# Hand the detail column to the evaluator as well: when it only receives the\n",
    "# predicted label, metrics.getLogLoss() comes back as None.\n",
    "evaluation = (\n",
    "    EvalMultiClassBatchOp()\n",
    "    .setLabelCol(\"label\")\n",
    "    .setPredictionCol(\"pred\")\n",
    "    .setPredictionDetailCol(\"detail\")\n",
    ")\n",
    "metrics = evaluation.linkFrom(res).collectMetrics()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 打印结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ConfusionMatrix: [[170, 3, 5, 0, 1, 7, 2, 2, 1, 0], [2, 154, 2, 1, 14, 3, 6, 9, 0, 2], [9, 3, 174, 0, 3, 3, 3, 3, 0, 0], [0, 0, 1, 162, 5, 4, 2, 6, 0, 7], [5, 9, 2, 5, 160, 1, 8, 1, 0, 0], [11, 4, 2, 0, 4, 187, 1, 2, 1, 1], [2, 5, 2, 2, 6, 1, 170, 4, 1, 0], [0, 2, 8, 4, 2, 4, 8, 180, 6, 1], [1, 3, 3, 1, 3, 1, 3, 3, 209, 0], [2, 2, 2, 0, 3, 1, 1, 2, 0, 179]]\n",
      "LabelArray: ['9', '8', '7', '6', '5', '4', '3', '2', '1', '0']\n",
      "LogLoss: None\n",
      "TotalSamples: 2000\n",
      "ActualLabelProportion: [0.101, 0.0925, 0.1005, 0.0875, 0.1005, 0.106, 0.102, 0.106, 0.109, 0.095]\n",
      "ActualLabelFrequency: [202, 185, 201, 175, 201, 212, 204, 212, 218, 190]\n",
      "Accuracy: 0.8725\n",
      "Kappa: 0.858283141946106\n"
     ]
    }
   ],
   "source": [
    "# Each entry pairs the printed label with the getter that produces its value;\n",
    "# the list order is the print order.\n",
    "report = [\n",
    "    (\"ConfusionMatrix\", metrics.getConfusionMatrix),\n",
    "    (\"LabelArray\", metrics.getLabelArray),\n",
    "    (\"LogLoss\", metrics.getLogLoss),\n",
    "    (\"TotalSamples\", metrics.getTotalSamples),\n",
    "    (\"ActualLabelProportion\", metrics.getActualLabelProportion),\n",
    "    (\"ActualLabelFrequency\", metrics.getActualLabelFrequency),\n",
    "    (\"Accuracy\", metrics.getAccuracy),\n",
    "    (\"Kappa\", metrics.getKappa),\n",
    "]\n",
    "for title, getter in report:\n",
    "    print(title + \":\", getter())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
